{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1atQIthOqnop",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Train a model with Classical Training\n",
    "\n",
    "Although episodic training has attracted a lot of interest in the early years of Few-Shot Learning research, more recent works suggest that competitive results can be achieved with a simple cross entropy loss across all training classes. Therefore, it is becoming more and more common to use this classical process to train the backbone, that will be common to all methods compared at test time.\n",
    "\n",
    "This is in fact more representative of real use cases: episodic training assumes that, at training time, you have access to the shape of the few-shot tasks that will be encountered at test time (indeed you choose a specific number of ways for episodic training). You also \"force\" your inference method into the training of the network. Switching the few-shot learning logic to inference (i.e. no episodic training) allows methods to be agnostic of the backbone.\n",
    "\n",
    "Nonetheless, if you need to perform episodic training, we also provide [an example notebook](episodic_training.ipynb) for that."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "axL7cV71qnoz",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Getting started\n",
    "First we're going to do some imports (this is not the interesting part)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MwkBm5_5tKw7",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "try:\n",
    "    import google.colab\n",
    "    colab = True\n",
    "except:\n",
    "    colab = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "-fnXeQdysyXW",
    "outputId": "9a5019b3-24c7-4cdb-9030-517c9a3ea524",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "if colab is True:\n",
    "    # Running in Google Colab\n",
    "    # Clone the repo\n",
    "    !git clone https://github.com/sicara/easy-few-shot-learning\n",
    "    %cd easy-few-shot-learning\n",
    "    !pip install .\n",
    "else:\n",
    "    # Run locally\n",
    "    # Ensure working directory is the project's root\n",
    "    # Make sure easyfsl is installed!\n",
    "    %cd .."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gZ8bXCqpqno0",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import copy\n",
    "from pathlib import Path\n",
    "import random\n",
    "from statistics import mean\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "50xGcwMrqno3",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Then we're gonna do the most important thing in Machine Learning research: ensuring reproducibility by setting the random seed. We're going to set the seed for all random packages that we could possibly use, plus some other stuff to make CUDA deterministic (see [here](https://pytorch.org/docs/stable/notes/randomness.html)).\n",
    "\n",
    "I strongly encourage that you do this in **all your scripts**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vpqXEWt8qno4",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "random_seed = 0\n",
    "np.random.seed(random_seed)\n",
    "torch.manual_seed(random_seed)\n",
    "random.seed(random_seed)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xW257QB4qno5",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Then we're gonna create our data loader for the training set. You can see that I chose tu use CUB in this notebook, because it’s a small dataset, so we can have good results quite quickly. I set a batch size of 128 but feel free to adapt it to your constraints.\n",
    "\n",
    "Note that we're not using the `TaskSampler` for the train data loader, because we won't be sampling training data in the shape of tasks as we would have in episodic training. We do it **normally**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "L0qMlS_at91K",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Download the CUB dataset\n",
    "!make download-cub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "3kYxNc4tqno7",
    "outputId": "85a76b58-4267-44dc-85fd-e9fff2243563",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from easyfsl.datasets import CUB\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "batch_size = 128\n",
    "n_workers = 12\n",
    "\n",
    "train_set = CUB(split=\"train\", training=True)\n",
    "train_loader = DataLoader(\n",
    "    train_set,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=n_workers,\n",
    "    pin_memory=True,\n",
    "    shuffle=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "G09f8Rkiqno8",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Now, we are going to create the model that we want to train. Here we choose the ResNet12 that is very often used in Few-Shot Learning research. Note that the default setting of these networks in EasyFSL is to not have a last fully connected layer (as it is usual for most Few-Shot Learning methods), but for classical training we need this layer! We also force it to output a vector which size is the number of different classes in the training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-zNE9ffbqnpA",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from easyfsl.modules import resnet12\n",
    "\n",
    "DEVICE = \"cuda\"\n",
    "\n",
    "model = resnet12(\n",
    "    use_fc=True,\n",
    "    num_classes=len(set(train_set.get_labels())),\n",
    ").to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5uP_qztRqnpB",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Now, we still need validation ! Since we're training a model to perform few-shot classification, we will validate on few-shot tasks, so now we'll use the `TaskSampler`. We arbitrarily set the shape of the validation tasks. Ideally, you'd like to perform validation on various shapes of tasks, but we didn't implement this yet (feel free to contribute!).\n",
    "\n",
    "We also need to define the few-shot classification method that we will use during validation of the neural network we're training.\n",
    "Here we choose Prototypical Networks, because it's simple and efficient, but this is still an arbitrary choice."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "DrUD75fEqnpC",
    "outputId": "fb3cd669-7cbb-40c8-9059-3ba359c96ac9",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from easyfsl.methods import PrototypicalNetworks\n",
    "from easyfsl.samplers import TaskSampler\n",
    "\n",
    "n_way = 5\n",
    "n_shot = 5\n",
    "n_query = 10\n",
    "n_validation_tasks = 500\n",
    "\n",
    "val_set = CUB(split=\"val\", training=False)\n",
    "val_sampler = TaskSampler(\n",
    "    val_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_validation_tasks\n",
    ")\n",
    "val_loader = DataLoader(\n",
    "    val_set,\n",
    "    batch_sampler=val_sampler,\n",
    "    num_workers=n_workers,\n",
    "    pin_memory=True,\n",
    "    collate_fn=val_sampler.episodic_collate_fn,\n",
    ")\n",
    "\n",
    "few_shot_classifier = PrototypicalNetworks(model).to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8DSkTh4zqnpD",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Training\n",
    "\n",
    "Now let's define our training helpers ! I chose to use Stochastic Gradient Descent on 200 epochs with a scheduler that divides the learning rate by 10 after 120 and 160 epochs. The strategy is derived from [this repo](https://github.com/fiveai/on-episodes-fsl).\n",
    "\n",
    "We're also gonna use a TensorBoard because it's always good to see what your training curves look like.\n",
    "\n",
    "An other thing: we're doing 200 epochs like in [the episodic training notebook](notebooks/episodic_training.ipynb), but keep in mind that an epoch in classical training means one pass through the 6000 images of the dataset, while in episodic training it's an arbitrary number of episodes. In the episodic training notebook an epoch is 500 episodes of 5-way, 5-shot, 10-query tasks, so 37500 images. TL;DR you may want to monitor your training and increase the number of epochs if necessary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "59DHe5LBqnpE",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from torch.optim import SGD, Optimizer\n",
    "from torch.optim.lr_scheduler import MultiStepLR\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "\n",
    "LOSS_FUNCTION = nn.CrossEntropyLoss()\n",
    "\n",
    "n_epochs = 200\n",
    "scheduler_milestones = [150, 180]\n",
    "scheduler_gamma = 0.1\n",
    "learning_rate = 1e-01\n",
    "tb_logs_dir = Path(\".\")\n",
    "\n",
    "train_optimizer = SGD(\n",
    "    model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4\n",
    ")\n",
    "train_scheduler = MultiStepLR(\n",
    "    train_optimizer,\n",
    "    milestones=scheduler_milestones,\n",
    "    gamma=scheduler_gamma,\n",
    ")\n",
    "\n",
    "tb_writer = SummaryWriter(log_dir=str(tb_logs_dir))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3RgyjdOUqnpF",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "And now let's get to it! Here we define the function that performs a training epoch.\n",
    "\n",
    "We use tqdm to monitor the training in real time in our logs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NxMeTTCwqnpF",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "def training_epoch(model_: nn.Module, data_loader: DataLoader, optimizer: Optimizer):\n",
    "    all_loss = []\n",
    "    model_.train()\n",
    "    with tqdm(data_loader, total=len(data_loader), desc=\"Training\") as tqdm_train:\n",
    "        for images, labels in tqdm_train:\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            loss = LOSS_FUNCTION(model_(images.to(DEVICE)), labels.to(DEVICE))\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            all_loss.append(loss.item())\n",
    "\n",
    "            tqdm_train.set_postfix(loss=mean(all_loss))\n",
    "\n",
    "    return mean(all_loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "y1P6GlRAqnpG",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "And we have everything we need! This is now the time to **start training**.\n",
    "\n",
    "A few notes:\n",
    "\n",
    "- We only validate every 10 epochs (you may set an even less frequent validation) because a training epoch is much faster than 500 few-shot tasks, and we don't want validation to be the bottleneck of our training process.\n",
    "\n",
    "- I also added something to log the state of the model that gave the best performance on the validation set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "QEm3sKqPqnpG",
    "outputId": "680e0012-3e7b-470d-f800-cbd75e41f773",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from easyfsl.utils import evaluate\n",
    "\n",
    "\n",
    "best_state = model.state_dict()\n",
    "best_validation_accuracy = 0.0\n",
    "validation_frequency = 10\n",
    "for epoch in range(n_epochs):\n",
    "    print(f\"Epoch {epoch}\")\n",
    "    average_loss = training_epoch(model, train_loader, train_optimizer)\n",
    "\n",
    "    if epoch % validation_frequency == validation_frequency - 1:\n",
    "\n",
    "        # We use this very convenient method from EasyFSL's ResNet to specify\n",
    "        # that the model shouldn't use its last fully connected layer during validation.\n",
    "        model.set_use_fc(False)\n",
    "        validation_accuracy = evaluate(\n",
    "            few_shot_classifier, val_loader, device=DEVICE, tqdm_prefix=\"Validation\"\n",
    "        )\n",
    "        model.set_use_fc(True)\n",
    "\n",
    "        if validation_accuracy > best_validation_accuracy:\n",
    "            best_validation_accuracy = validation_accuracy\n",
    "            best_state = copy.deepcopy(few_shot_classifier.state_dict())\n",
    "            # state_dict() returns a reference to the still evolving model's state so we deepcopy\n",
    "            # https://pytorch.org/tutorials/beginner/saving_loading_models\n",
    "            print(\"Ding ding ding! We found a new best model!\")\n",
    "\n",
    "        tb_writer.add_scalar(\"Val/acc\", validation_accuracy, epoch)\n",
    "\n",
    "    tb_writer.add_scalar(\"Train/loss\", average_loss, epoch)\n",
    "\n",
    "    # Warn the scheduler that we did an epoch\n",
    "    # so it knows when to decrease the learning rate\n",
    "    train_scheduler.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cHWLcJ99qnpH",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Yay we successfully performed Classical Training! Now if you want to you can retrieve the best model's state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "EG9bBTz4qnpH",
    "outputId": "f4aed221-a7fc-4563-d51c-d6606fd704ac",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model.load_state_dict(best_state)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "l0iEJFIiqnpI",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Evaluation\n",
    "\n",
    "Now that our model is trained, we want to test it.\n",
    "\n",
    "First step: we fetch the test data. Note that we'll evaluate on the same shape of tasks as in validation. This is malicious practice, because it means that we used *a priori* information about the evaluation tasks during training. This is still less malicious than episodic training, though."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "eveON7-wqnpI",
    "outputId": "14f75db7-8395-4019-ef4a-73d8ce72bca9",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "n_test_tasks = 1000\n",
    "\n",
    "test_set = CUB(split=\"test\", training=False)\n",
    "test_sampler = TaskSampler(\n",
    "    test_set, n_way=n_way, n_shot=n_shot, n_query=n_query, n_tasks=n_test_tasks\n",
    ")\n",
    "test_loader = DataLoader(\n",
    "    test_set,\n",
    "    batch_sampler=test_sampler,\n",
    "    num_workers=n_workers,\n",
    "    pin_memory=True,\n",
    "    collate_fn=test_sampler.episodic_collate_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Pr5aDnD-qnpJ",
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Second step: we instantiate a few-shot classifier using our trained ResNet as backbone, and run it on the test data. We keep using Prototypical Networks for consistence, but at this point you could basically use any few-shot classifier that takes no additional trainable parameters.\n",
    "\n",
    "Like we did during validation, we need to tell our ResNet to not use its last fully connected layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "lK6WOtpwqnpJ",
    "outputId": "89f92fcd-c8b5-4bda-d1b7-294ccfaf19c0",
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model.set_use_fc(False)\n",
    "\n",
    "accuracy = evaluate(few_shot_classifier, test_loader, device=DEVICE)\n",
    "print(f\"Average accuracy : {(100 * accuracy):.2f} %\")"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "id": "qe5fBbm4qnpJ",
    "pycharm": {
     "name": "#%% raw\n"
    }
   },
   "source": [
    "Congrats! You trained a network with cross entropy and used it with a few-shot learning method at test time."
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "classical_training.ipynb",
   "provenance": []
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}